import os
import configparser
import json
from tqdm import tqdm

from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary
from llm.llm_const import (
    MODEL_ID_SILICONFLOW_DEEPSEEK_V3,
    MODEL_ID_SILICONFLOW_QWEN3_14B,
    MODEL_ID_GITEE_QWEN3_32B,
    MODEL_ID_CLAUDE_SONNET_4,
    MODEL_ID_SILICONFLOW_BGE_M3,
    MODEL_ID_LLAMA3_8B,
)

# Directory containing this file; the config files, assets, cache and output
# locations used below are all resolved relative to it.
project_root = os.path.dirname(__file__)
# Shared logger handed to the LLM wrapper, the auxiliary helper and every attack.
logger = Logger(__name__, 'dev')

def read_config(local_config_path, default_config_path):
    """Load the INI configuration, preferring the local file over the default.

    Args:
        local_config_path: Path of the optional local override file.
        default_config_path: Path of the fallback file, used when the local
            file is missing or cannot be opened.

    Returns:
        A ``configparser.ConfigParser`` holding the first of the two files
        that could be read. If neither can be read the parser is returned
        empty and a message saying so is printed.
    """
    config = configparser.ConfigParser()

    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed. Use that result (rather than
    # os.path.exists) to decide on the fallback and on what to report, so the
    # "Loaded" message is only printed for a file that was really loaded.
    for path in (local_config_path, default_config_path):
        if config.read(path):
            print(f'Loaded configuration from {path}')
            return config

    print(
        f'No configuration could be read from {local_config_path} '
        f'or {default_config_path}'
    )
    return config

# Experiment configuration: local_config.ini next to this file takes
# precedence; default_config.ini is the fallback.
test_config = read_config(
    os.path.join(project_root, 'local_config.ini'),
    os.path.join(project_root, 'default_config.ini'),
)

def _init_llm(model_id, api_key, model_save_path=None, model_cache_path=None, token=None):
    """Create an ``LLMWrapper`` for the given model settings.

    The arguments are packed into the wrapper's config dict under the keys
    ``model_id``, ``model_save_path``, ``model_cache_path``, ``api_key`` and
    ``token``, and the module-level ``logger`` is passed along. The wrapper is
    returned as constructed; this function does not call ``init()`` on it.
    """
    return LLMWrapper(
        config={
            'model_id': model_id,
            'model_save_path': model_save_path,
            'model_cache_path': model_cache_path,
            'api_key': api_key,
            'token': token,
        },
        logger=logger,
    )

# Directories and credentials read from the [BASE] and [API_KEY] sections.
huggingface_cache_dir = test_config.get('BASE', 'huggingface-cache')
huggingface_save_dir = test_config.get('BASE', 'huggingface-save')
huggingface_token = test_config.get('API_KEY', 'huggingface')
siliconflow_api_key = test_config.get('API_KEY', 'siliconflow')
claude_api_key = test_config.get('API_KEY', 'claude')
# Experiment selectors from [EXP]: which generation model and which embedding
# model the run uses.
llm_type = test_config.get('EXP', 'llm')
embedding = test_config.get('EXP', 'embedding')

# 'model' subdirectories of the configured save and cache directories; both
# are passed on to the LLM wrapper.
model_save_path = os.path.join(huggingface_save_dir, 'model')
model_cache_path = os.path.join(huggingface_cache_dir, 'model')

# Resolve the generation model for the configured `llm_type`.
# `token` stays None except for "llama", which is given the Hugging Face
# token and no API key.
token = None
if llm_type == "ds":
    model_id = MODEL_ID_SILICONFLOW_DEEPSEEK_V3
    model_api_key = siliconflow_api_key
elif llm_type == "qwen":
    model_id = MODEL_ID_SILICONFLOW_QWEN3_14B
    model_api_key = siliconflow_api_key
elif llm_type == "claude":
    model_id = MODEL_ID_CLAUDE_SONNET_4
    model_api_key = claude_api_key
elif llm_type == "llama":
    model_id = MODEL_ID_LLAMA3_8B
    model_api_key = None
    token = huggingface_token
else:
    # Fail here with a clear message instead of a NameError on `model_id`
    # further down.
    raise ValueError(
        f"Unsupported [EXP] llm value {llm_type!r}; "
        "expected one of: ds, qwen, claude, llama"
    )

# Resolve the embedding model for the configured `embedding`.
if embedding == "bge_m3":
    embedding_id = MODEL_ID_SILICONFLOW_BGE_M3
    embedding_api_key = siliconflow_api_key
else:
    # Same reasoning: `embedding_id` would otherwise be undefined when the
    # shared config dict is built.
    raise ValueError(
        f"Unsupported [EXP] embedding value {embedding!r}; expected: bge_m3"
    )

# Build the generation model wrapper from the values resolved above and
# initialise it.
llm = _init_llm(
    model_id=model_id,
    api_key=model_api_key,
    model_save_path=model_save_path,
    model_cache_path=model_cache_path,
    token=token,
)
llm.init()

# Per-run locations: cache and output are namespaced by LLM type and dataset.
assets_dir = os.path.join(project_root, 'assets')
dataset = test_config.get('EXP', 'dataset')
cache_dir = os.path.join(project_root, 'cache', llm_type, dataset)
output_dir = os.path.join(project_root, 'output', llm_type, dataset)

# The string "none" (in any casing) in the config stands for "no model name".
model_name = test_config.get('EXP', 'model_name')
model_name = None if model_name.lower() == 'none' else model_name

# Settings shared by the auxiliary helper and every attack.
config = {
    'cache_dir': cache_dir,
    'output_dir': output_dir,
    'api_generate_model_id': model_id,
    'api_embedding_model_id': embedding_id,
    'generate_api_key': model_api_key,
    'embedding_api_key': embedding_api_key,
    'model_name': model_name,
}

auxiliary = Auxiliary(logger, config)

from attack.baseline import Baseline
from attack.dra import DRA
from attack.art_prompt import ArtPrompt
from attack.flip_attack import FlipAttack
from attack.pair import PAIR
from attack.sata import SATA
from attack.pass_attack import PASS

# Map the configured dataset name to its JSONL file under the assets directory.
if dataset == "adv_bench":
    dataset_path = os.path.join(assets_dir, "category_adv_bench_test.jsonl")
elif dataset == "jbb":
    dataset_path = os.path.join(assets_dir, "jbb.jsonl")
else:
    # Fail here with a clear message instead of a NameError on `dataset_path`
    # once the experiment starts.
    raise ValueError(
        f"Unsupported [EXP] dataset value {dataset!r}; "
        "expected one of: adv_bench, jbb"
    )

# Attack names in the order they are run.
attacks = [
    "baseline",
    "dra",
    "art_prompt",
    "flip_attack",
    "pair",
    "sata",
    "pass",
]

def _load_items(path):
    """Read the JSONL dataset at *path* into a list of attack inputs.

    Every non-blank line must be a JSON object with ``query``, ``target`` and
    ``category`` keys. A dict category is flattened to
    ``"<primary_category>: <secondary_category>"``; any other value is
    converted with ``str()``.

    Returns:
        A list of dicts with the keys ``query``, ``target`` and ``category``.
    """
    items = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. an extra newline at end of file)
                # instead of failing in json.loads.
                continue

            info = json.loads(line)
            category = info['category']
            if isinstance(category, dict):
                primary_category = category['primary_category']
                secondary_category = category['secondary_category']
                category_str = f"{primary_category}: {secondary_category}"
            else:
                category_str = str(category)

            items.append({
                "query": info['query'],
                "target": info['target'],
                "category": category_str,
            })
    return items


def run_exp():
    """Run every attack listed in ``attacks`` over the configured dataset.

    For each attack, every dataset item is attacked once, each result is
    appended as one JSON line to ``<output_dir>/<attack>_results.jsonl``, and
    the attack success rate (ASR) is logged when the attack has processed all
    items. An item whose attack raises is logged and counted as unsuccessful.

    Raises:
        ValueError: if ``attacks`` contains a name with no attack class.
    """
    attack_classes = {
        "baseline": Baseline,
        "dra": DRA,
        "art_prompt": ArtPrompt,
        "flip_attack": FlipAttack,
        "pair": PAIR,
        "sata": SATA,
        "pass": PASS,
    }
    # Validate up front so a typo in `attacks` fails before any work is done,
    # rather than silently reusing the previous attack object mid-run.
    unknown = [name for name in attacks if name not in attack_classes]
    if unknown:
        raise ValueError(f"Unknown attack type(s): {unknown}")

    items = _load_items(dataset_path)
    os.makedirs(output_dir, exist_ok=True)
    total_attacks = len(items)

    with tqdm(total=total_attacks * len(attacks), desc="Running experiments") as pbar:
        for attack_type in attacks:
            attack = attack_classes[attack_type](logger, llm, auxiliary, config)
            result_file = os.path.join(output_dir, f"{attack_type}_results.jsonl")

            logger.info(f"Start to exp...[{llm_type}]-[{dataset}]-[{attack_type}]")
            # The description only changes per attack, so set it once here.
            pbar.set_description(f"Running {attack_type}")

            successful_attacks = 0
            for item in items:
                try:
                    result = attack.attack(item['query'], item['target'], item['category'])
                    success = result['judge']['success_status']

                    # Appending per item keeps everything written so far on
                    # disk even if a later item aborts the run.
                    with open(result_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(result, ensure_ascii=False) + '\n')

                    if success:
                        successful_attacks += 1
                except Exception as e:
                    # Best effort: one failing item must not stop the run.
                    logger.log_exception(e)

                pbar.update(1)

            # An empty dataset has nothing to score; report 0 rather than
            # dividing by zero.
            asr = successful_attacks / total_attacks if total_attacks else 0.0
            logger.info(f"{attack_type} experiment completed. Final ASR: {asr:.4f}")

# Runs the full experiment whenever this module is executed or imported
# (there is no `if __name__ == "__main__":` guard).
run_exp()
